{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the following dependencies to be installed:\n",
    "`pip install lightly`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the package is installed into the environment of the\n",
    "# running kernel rather than whichever pip is first on the shell PATH.\n",
    "%pip install lightly"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Note: The model and training settings do not follow the reference settings\n",
    "from the paper. The settings are chosen such that the example can easily be\n",
    "run on a small dataset with a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import torchvision\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.loss import VICRegLLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "## The global projection head is the same as the Barlow Twins one\n",
    "from lightly.models.modules import BarlowTwinsProjectionHead\n",
    "from lightly.models.modules.heads import VicRegLLocalProjectionHead\n",
    "from lightly.transforms.vicregl_transform import VICRegLTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VICRegL(pl.LightningModule):\n",
    "    \"\"\"VICRegL model: a ResNet-18 backbone with a global and a local projection head.\n",
    "\n",
    "    The backbone keeps the spatial feature map. Every image therefore yields one\n",
    "    pooled (global) embedding and one embedding per feature-map location (local).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        resnet = torchvision.models.resnet18()\n",
    "        # Drop the last two child modules of the ResNet (its pooling and\n",
    "        # classification layers) so the backbone returns a spatial feature map.\n",
    "        self.backbone = nn.Sequential(*list(resnet.children())[:-2])\n",
    "        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)\n",
    "        self.local_projection_head = VicRegLLocalProjectionHead(512, 128, 128)\n",
    "        self.average_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))\n",
    "        self.criterion = VICRegLLoss()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the global and local projections ``(z, z_local)`` for a batch of images.\"\"\"\n",
    "        x = self.backbone(x)\n",
    "        # Global branch: pool the feature map to one vector per image.\n",
    "        y = self.average_pool(x).flatten(start_dim=1)\n",
    "        z = self.projection_head(y)\n",
    "        # Local branch: move the channel axis last so the head is applied per location.\n",
    "        y_local = x.permute(0, 2, 3, 1)  # (B, D, H, W) to (B, H, W, D)\n",
    "        z_local = self.local_projection_head(y_local)\n",
    "        return z, z_local\n",
    "\n",
    "    def training_step(self, batch, batch_index):\n",
    "        \"\"\"Compute the VICRegL loss for one batch.\n",
    "\n",
    "        The first half of ``batch[0]`` is treated as the views and the second half\n",
    "        as their grids. The first two views/grids are passed to the loss as global\n",
    "        views, all remaining ones as local views.\n",
    "        \"\"\"\n",
    "        views_and_grids = batch[0]\n",
    "        views = views_and_grids[: len(views_and_grids) // 2]\n",
    "        grids = views_and_grids[len(views_and_grids) // 2 :]\n",
    "        features = [self.forward(view) for view in views]\n",
    "        loss = self.criterion(\n",
    "            global_view_features=features[:2],\n",
    "            global_view_grids=grids[:2],\n",
    "            local_view_features=features[2:],\n",
    "            local_view_grids=grids[2:],\n",
    "        )\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Create the SGD optimizer over this module's own parameters.\"\"\"\n",
    "        # Use self.parameters() instead of a global `model` variable so the class\n",
    "        # does not depend on notebook state defined in a later cell.\n",
    "        optim = torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9)\n",
    "        return optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = VICRegL()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = VICRegLTransform(n_local_views=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_transform(t):\n",
    "    \"\"\"Replace every object detection annotation with the constant label 0.\"\"\"\n",
    "    # The object detection annotations are ignored, so a dummy target is enough.\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = torchvision.datasets.VOCDetection(\n",
    "    \"datasets/pascal_voc\",\n",
    "    download=True,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    ")\n",
    "# or create a dataset from a folder containing images or videos:\n",
    "# dataset = LightlyDataset(\"path/to/folder\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    num_workers=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n",
    "# calculation. Distributed sampling is enabled with use_distributed_sampler=True.\n",
    "# The \"ddp_notebook\" strategy is used because the plain \"ddp\" strategy is not\n",
    "# supported in interactive environments such as Jupyter notebooks.\n",
    "trainer = pl.Trainer(\n",
    "    max_epochs=10,\n",
    "    devices=\"auto\",\n",
    "    accelerator=\"gpu\",\n",
    "    strategy=\"ddp_notebook\",\n",
    "    sync_batchnorm=True,\n",
    "    use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n",
    ")\n",
    "trainer.fit(model=model, train_dataloaders=dataloader)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
